{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eed2e34c",
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Optional, Type, List, Tuple\n",
    "from langchain.pydantic_v1 import BaseModel, Field\n",
    "from langchain.tools import BaseTool\n",
    "from langchain_community.graphs import Neo4jGraph\n",
    "from langchain.agents import AgentExecutor\n",
    "from langchain.agents.format_scratchpad import format_to_openai_function_messages\n",
    "from langchain.agents.output_parsers import OpenAIFunctionsAgentOutputParser\n",
    "from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder\n",
    "from langchain_core.messages import AIMessage, HumanMessage\n",
    "from langchain_core.utils.function_calling import convert_to_openai_function\n",
    "from langchain_openai import ChatOpenAI\n",
    "from langchain.callbacks.manager import (\n",
    "    AsyncCallbackManagerForToolRun,\n",
    "    CallbackManagerForToolRun,\n",
    ")\n",
    "import openai\n",
    "import os"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "96a3729b-aadc-41f9-9743-6941b6aaf83b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# OpenAI credentials and endpoint: replace both placeholders before running.\n",
    "os.environ.update({\"OPENAI_API_KEY\": \"openai_api_key\"})\n",
    "openai.api_base = \"base_url\"  # custom endpoint, read back when the chat model is built"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a633aab1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Neo4j connection settings, exported as environment variables.\n",
    "# The values below are placeholders: replace them before running.\n",
    "os.environ.update(\n",
    "    {\n",
    "        \"NEO4J_URI\": \"neo4j_url\",\n",
    "        \"NEO4J_USERNAME\": \"neo4j\",\n",
    "        \"NEO4J_PASSWORD\": \"password\",\n",
    "    }\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e16111be",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Open the Neo4j connection through LangChain (no explicit arguments are passed),\n",
    "# then reload the schema and print it for inspection.\n",
    "graph = Neo4jGraph()\n",
    "graph.refresh_schema()\n",
    "print(graph.schema)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "66c5b257",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cypher template: picks one node whose `title` or `name` contains $candidate and\n",
    "# renders its first label, its title and its neighbours (grouped by relationship\n",
    "# type) as a single text block named `context`.\n",
    "description_query = \"\"\"\n",
    "MATCH (m:病名|症状|用药|治则|病机)\n",
    "WHERE m.title CONTAINS $candidate OR m.name CONTAINS $candidate\n",
    "MATCH (m)-[r:治疗|属于|运用|推荐用药|遵循|导致]-(t)\n",
    "WITH m, type(r) as type, collect(coalesce(t.name, t.title)) as names\n",
    "WITH m, type+\": \"+reduce(s=\"\", n IN names | s + n + \", \") as types\n",
    "WITH m, collect(types) as contexts\n",
    "WITH m, \"type:\" + labels(m)[0] + \"\\ntitle: \"+ coalesce(m.title, m.name)  +\"\\n\" +\n",
    "       reduce(s=\"\", c in contexts | s + substring(c, 0, size(c)-2) +\"\\n\") as context\n",
    "RETURN context LIMIT 1\n",
    "\"\"\"\n",
    "\n",
    "def get_information(entity: str) -> str:\n",
    "    \"\"\"Look up `entity` in the graph and return a text summary of the match.\n",
    "\n",
    "    Args:\n",
    "        entity: Text matched as a substring of node `title` / `name` properties.\n",
    "\n",
    "    Returns:\n",
    "        The `context` value of the first returned row, or\n",
    "        \"No information was found\" when the input is blank or nothing matches.\n",
    "    \"\"\"\n",
    "    not_found = \"No information was found\"\n",
    "    # A blank candidate makes `CONTAINS` true for any node that has a title or\n",
    "    # name, so the query would return an arbitrary node; reject it up front.\n",
    "    if not entity or not entity.strip():\n",
    "        return not_found\n",
    "    try:\n",
    "        data = graph.query(description_query, params={\"candidate\": entity})\n",
    "        return data[0][\"context\"]\n",
    "    except IndexError:\n",
    "        # No rows came back for this candidate.\n",
    "        return not_found"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "26ec383c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Argument schema with a single free-text `entity` field.\n",
    "class InformationInput(BaseModel):\n",
    "    entity: str = Field(\n",
    "        description=\"something about traditional Chinese medicine mentioned in the question\",\n",
    "    )\n",
    "\n",
    "\n",
    "class InformationTool(BaseTool):\n",
    "    \"\"\"LangChain tool that answers entity look-ups by calling `get_information`.\"\"\"\n",
    "\n",
    "    # Annotated so these overrides remain valid model fields; pydantic v2 rejects\n",
    "    # un-annotated overrides of fields declared on a base class.\n",
    "    name: str = \"Information\"\n",
    "    description: str = (\n",
    "        \"useful for when you need to answer questions about traditional Chinese medicine\"\n",
    "    )\n",
    "    args_schema: Type[BaseModel] = InformationInput\n",
    "\n",
    "    def _run(\n",
    "        self,\n",
    "        entity: str,\n",
    "        run_manager: Optional[CallbackManagerForToolRun] = None,\n",
    "    ) -> str:\n",
    "        \"\"\"Use the tool.\n",
    "\n",
    "        Args:\n",
    "            entity: Text to look up in the graph.\n",
    "            run_manager: Optional callback manager; not used by this tool.\n",
    "\n",
    "        Returns:\n",
    "            The string returned by `get_information(entity)`.\n",
    "        \"\"\"\n",
    "        return get_information(entity)\n",
    "\n",
    "    async def _arun(\n",
    "        self,\n",
    "        entity: str,\n",
    "        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,\n",
    "    ) -> str:\n",
    "        \"\"\"Use the tool asynchronously.\n",
    "\n",
    "        Args:\n",
    "            entity: Text to look up in the graph.\n",
    "            run_manager: Optional async callback manager; not used by this tool.\n",
    "\n",
    "        Returns:\n",
    "            The string returned by `get_information(entity)`.\n",
    "        \"\"\"\n",
    "        import asyncio\n",
    "\n",
    "        # get_information issues a synchronous graph query; run it in the default\n",
    "        # executor so the event loop is not blocked while it waits.\n",
    "        loop = asyncio.get_running_loop()\n",
    "        return await loop.run_in_executor(None, get_information, entity)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "080061f6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Chat model. The notebook only sets the OPENAI_API_KEY environment variable and\n",
    "# never assigns `openai.api_key`, so prefer the environment value and keep the\n",
    "# module attribute as a fallback.\n",
    "llm = ChatOpenAI(\n",
    "    model=\"gpt-3.5-turbo\",\n",
    "    temperature=0,\n",
    "    openai_api_key=os.environ.get(\"OPENAI_API_KEY\", openai.api_key),\n",
    "    openai_api_base=openai.api_base,\n",
    ")\n",
    "tools = [InformationTool()]\n",
    "\n",
    "# Expose every tool to the model as an OpenAI function definition.\n",
    "llm_with_tools = llm.bind(functions=[convert_to_openai_function(t) for t in tools])\n",
    "\n",
    "# Prompt layout: system instructions, prior chat turns, the user's question, then\n",
    "# the agent scratchpad messages.\n",
    "prompt = ChatPromptTemplate.from_messages(\n",
    "    [\n",
    "        (\n",
    "            \"system\",\n",
    "            \"You are a helpful assistant that finds information about traditional Chinese medicine \"\n",
    "            \" and recommends them. \"\n",
    "            \"Do only the things the user specifically requested. \"\n",
    "            \"Start with 'According to the medical cases of Wang Zhongqi' when answering, \"\n",
    "            \"end with 'The results show individual phenomena, and medical conditions vary from person to person. Please consider your own situation and always consult a professional doctor's advice for optimal care.'\",\n",
    "        ),\n",
    "        MessagesPlaceholder(variable_name=\"chat_history\"),\n",
    "        (\"user\", \"{input}\"),\n",
    "        MessagesPlaceholder(variable_name=\"agent_scratchpad\"),\n",
    "    ]\n",
    ")\n",
    "\n",
    "\n",
    "def _format_chat_history(chat_history: List[Tuple[str, str]]):\n",
    "    \"\"\"Flatten (human, ai) text pairs into one alternating list of messages.\n",
    "\n",
    "    Each pair contributes a HumanMessage followed by an AIMessage, in input order.\n",
    "    \"\"\"\n",
    "    return [\n",
    "        message\n",
    "        for user_text, assistant_text in chat_history\n",
    "        for message in (\n",
    "            HumanMessage(content=user_text),\n",
    "            AIMessage(content=assistant_text),\n",
    "        )\n",
    "    ]\n",
    "\n",
    "\n",
    "# Agent pipeline: shape the inputs, render the prompt, call the function-enabled\n",
    "# model, then parse its output.\n",
    "agent = (\n",
    "    {\n",
    "        \"input\": lambda payload: payload[\"input\"],\n",
    "        # A missing or empty history yields an empty message list.\n",
    "        \"chat_history\": lambda payload: _format_chat_history(\n",
    "            payload.get(\"chat_history\") or []\n",
    "        ),\n",
    "        \"agent_scratchpad\": lambda payload: format_to_openai_function_messages(\n",
    "            payload[\"intermediate_steps\"]\n",
    "        ),\n",
    "    }\n",
    "    | prompt\n",
    "    | llm_with_tools\n",
    "    | OpenAIFunctionsAgentOutputParser()\n",
    ")\n",
    "\n",
    "# Executor that runs the agent with the tool list; verbose output is enabled.\n",
    "agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "48f187aa",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Example run: send one question to the agent.\n",
    "agent_executor.invoke({\"input\": \"外感湿邪可能导致什么症状？\"})  # input query"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "QAM",
   "language": "python",
   "name": "qam"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
